"""
Compute the FVD between a folder of generated videos and a ground-truth folder.

Usage:
    CUDA_VISIBLE_DEVICES=0 python eval_FVD.py
"""
# Root directory that contains both video sets.
dataset_dir = "./../dataset"
# Generated ("fake") videos, relative to dataset_dir.
test_id = "rebuttal/video"
# Ground-truth ("real") videos, relative to dataset_dir.
gt_id = "GT"

import sys
import cv2
import numpy as np
from PIL import Image
import os 
import PIL

# Frame height / width (pixels) every clip is resized to before scoring.
H, W = 256, 256
# Number of leading frames kept from each clip.
num_frames = 14

def export_to_video(
    video_frames, output_video_path, fps
):
    """Encode a sequence of RGB frames into an MP4 file using OpenCV.

    Args:
        video_frames: Non-empty sequence of frames. Either all ``np.ndarray``
            (scaled by 255 and cast to uint8 here, so presumably float RGB in
            [0, 1] — TODO confirm for any new caller) or all ``PIL.Image.Image``
            (converted with ``np.array``). Frames must be H x W x C.
        output_video_path: Destination path. If ``None``, a temporary ``.mp4``
            path name is generated.
        fps: Frame rate written into the container.

    Returns:
        The path the video was written to.

    Raises:
        ValueError: If ``video_frames`` is empty.
        OSError: If OpenCV cannot open a writer for ``output_video_path``.
    """
    import tempfile  # only needed for the output_video_path=None fallback

    if len(video_frames) == 0:
        raise ValueError("export_to_video: video_frames is empty")

    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name

    if isinstance(video_frames[0], np.ndarray):
        video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]
    elif isinstance(video_frames[0], PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    h, w, _ = video_frames[0].shape
    # OpenCV takes frameSize as (width, height).
    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
    if not video_writer.isOpened():
        # Otherwise every write() below is silently dropped and no file appears.
        raise OSError(f"could not open video writer for {output_video_path}")
    try:
        for frame in video_frames:
            # Frames are RGB; OpenCV writes BGR.
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Explicitly finalize the container instead of relying on garbage collection.
        video_writer.release()
    return output_video_path

def extract_frames(file_id):
    """Build an FVD-ready copy of the videos in ``dataset_dir/file_id``.

    Each ``.mp4`` in the source folder is truncated to its first ``num_frames``
    frames, resized to ``W`` x ``H``, and re-encoded at 7 fps into
    ``dataset_dir/<file_id>_fvd``. That output folder doubles as a cache: if it
    already exists it is returned as-is and nothing is re-extracted.

    Args:
        file_id: Sub-folder of ``dataset_dir`` holding the source videos.

    Returns:
        Path of the folder containing the re-encoded clips.

    Raises:
        ValueError: If the source folder does not hold 329 or 330 entries
            (the size this evaluation expects).
    """
    import shutil  # staging-directory cleanup

    extract_dir = os.path.join(dataset_dir, file_id)
    save_dir = os.path.join(dataset_dir, file_id + "_fvd")

    if os.path.exists(save_dir):
        print(save_dir + " exists!!")
        return save_dir

    # Validate before touching the filesystem. Creating save_dir first meant a
    # failed check left an empty save_dir that later runs treated as a cache.
    file_list = sorted(os.listdir(extract_dir))
    if len(file_list) not in (329, 330):
        raise ValueError(
            f"expected 329 or 330 entries in {extract_dir}, found {len(file_list)}"
        )

    # Write into a staging dir and rename only after every clip succeeded, so an
    # interrupted run cannot leave a half-filled cache behind.
    staging_dir = save_dir + ".partial"
    shutil.rmtree(staging_dir, ignore_errors=True)
    os.makedirs(staging_dir)

    for i, f in enumerate(file_list):
        if not f.endswith(".mp4"):
            continue
        print("reading:", i, f)
        cam = cv2.VideoCapture(os.path.join(extract_dir, f))
        pil_images = []
        try:
            while len(pil_images) < num_frames:
                ok, frame = cam.read()
                if not ok:
                    # Clip ended early or failed to decode; `frame` is None here.
                    break
                frame = cv2.resize(frame, (W, H))  # cv2 expects (width, height)
                # OpenCV decodes BGR; export_to_video expects RGB input.
                pil_images.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        finally:
            cam.release()

        if len(pil_images) < num_frames:
            print(f"warning: {f} gave only {len(pil_images)}/{num_frames} frames")
        if not pil_images:
            continue  # nothing decodable; an empty clip cannot be encoded
        export_to_video(pil_images, os.path.join(staging_dir, f), fps=7)

    os.replace(staging_dir, save_dir)
    return save_dir
    
save_dir1 = extract_frames(test_id)  # generated ("fake") clips
save_dir2 = extract_frames(gt_id)    # ground-truth ("real") clips

from cdfvd import fvd
# Standard FVD: I3D backbone. n_real/n_fake="full" presumably means "use every
# clip in the folder" — see the cd-fvd docs.
evaluator = fvd.cdfvd('i3d', n_real = "full", n_fake="full", ckpt_path=None)
# Load both sets with the same resolution and clip length that extract_frames
# wrote, instead of repeating the literal 14 here.
evaluator.compute_real_stats(evaluator.load_videos(video_info = save_dir2, data_type='video_folder', resolution=H, sequence_length=num_frames))
evaluator.compute_fake_stats(evaluator.load_videos(video_info = save_dir1, data_type='video_folder', resolution=H, sequence_length=num_frames))
score = evaluator.compute_fvd_from_stats()

# Report FVD (Frechet Video Distance) between the two sets
print("score:", score)
